import cv2
import pandas as pd
from PIL import Image
from torch.utils.data import Dataset


class ExpressionDataset(Dataset):
    """
    Testing Dataset Adapter for Torch

    Reads a header-less, comma-separated index file and serves one image per
    row. The columns are read by position: column 0 is kept as the shot id,
    column 1 as the image name and column 2 as the path the image is loaded
    from.
    """

    def __init__(self, test_set, transform=None):
        """
        :param test_set: Path to the test index (CSV without a header row)
        :param transform: Transformer that will be applied on the image data;
            when ``None`` the PIL image is returned untouched
        """
        super().__init__()

        index = pd.read_csv(test_set, sep=',', header=None)

        # Columns are addressed by position because the file has no header.
        self.shot_id = index[0].to_list()
        self.img_name = index[1].to_list()
        self.img_paths = index[2].to_list()

        self.transform = transform

    def __len__(self):
        """
        :return: Number of rows (images) listed in the index
        """
        return len(self.img_paths)

    def __getitem__(self, idx):
        """
        Load the image of row ``idx`` and return it with its identifiers.

        :param idx: Position of the sample in the index
        :return: Tuple ``(image, image name, shot id)``; ``image`` is an RGB
            PIL image, or the transform's output if a transform was given
        :raises FileNotFoundError: If the image cannot be read from disk
        """
        path = self.img_paths[idx]

        # cv2.imread does not raise on a missing or undecodable file, it
        # returns None; fail here with the offending path instead of letting
        # cvtColor abort with an opaque assertion error.
        bgr = cv2.imread(path)
        if bgr is None:
            raise FileNotFoundError(
                'Could not read image (missing or unsupported file): '
                '{!r}'.format(path)
            )

        # OpenCV decodes to BGR channel order, PIL expects RGB.
        image = Image.fromarray(cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB))

        if self.transform is not None:
            image = self.transform(image)

        return image, self.img_name[idx], self.shot_id[idx]
